package com.isyscore.os.etl.utils;

import com.isyscore.os.core.exception.DataFactoryException;
import com.isyscore.os.core.exception.ErrorCode;
import com.isyscore.os.etl.model.SqlCommandCall;
import com.isyscore.os.etl.model.enums.SqlCommand;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.config.Lex;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.sql.parser.validate.FlinkSqlConformance;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.*;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.config.TableConfigOptions;
import org.apache.flink.table.planner.calcite.CalciteConfig;
import org.apache.flink.table.planner.calcite.CalciteParser;
import org.apache.flink.table.planner.delegation.FlinkSqlParserFactories;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.TableConfigUtils;

import java.util.*;
import java.util.regex.Matcher;


@Slf4j
public class SqlValidationUtils {
    public static void preCheckSql(List<String> sql) {

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        EnvironmentSettings settings = EnvironmentSettings.newInstance()
                .useBlinkPlanner()
                .inStreamingMode()
                .build();

        TableEnvironment tEnv = StreamTableEnvironment.create(env, settings);

        List<SqlCommandCall> sqlCommandCallList = fileToSql(sql);
        if (CollectionUtils.isEmpty(sqlCommandCallList)) {
            log.error("没解析出sql，请检查语句，如: 缺少分号");
            throw new DataFactoryException(ErrorCode.VALIDATE_ERROR, "没解析出sql，请检查语句，如: 缺少分号");
        }

        TableConfig config = tEnv.getConfig();
        String value = null;

        boolean isInsertSql = false;

        boolean isSelectSql = false;
        try {
            for (SqlCommandCall sqlCommandCall : sqlCommandCallList) {

                value = sqlCommandCall.operands[0];

                //配置
                if (sqlCommandCall.sqlCommand == SqlCommand.SET) {
                    String key = sqlCommandCall.operands[0];
                    String val = sqlCommandCall.operands[1];
                    if (val.contains("\n")) {
                        log.error("set语法值异常：" + val);
                        throw new DataFactoryException(ErrorCode.VALIDATE_ERROR, "set语法值异常");
                    }
                    if (TableConfigOptions.TABLE_SQL_DIALECT.key().equalsIgnoreCase(key.trim())
                            && SqlDialect.HIVE.name().equalsIgnoreCase(val.trim())) {
                        config.setSqlDialect(SqlDialect.HIVE);
                    } else {
                        config.setSqlDialect(SqlDialect.DEFAULT);
                    }

                    //其他
                } else {
                    if (SqlCommand.INSERT_INTO.equals(sqlCommandCall.sqlCommand)
                            || SqlCommand.INSERT_OVERWRITE.equals(sqlCommandCall.sqlCommand)) {
                        isInsertSql = true;
                    }
                    if (SqlCommand.SELECT.equals(sqlCommandCall.sqlCommand)) {
                        isSelectSql = true;
                    }
                    CalciteParser parser = new CalciteParser(getSqlParserConfig(config));
                    parser.parse(sqlCommandCall.operands[0]);
                }
            }
        } catch (Exception e) {
            log.error("语法异常：sql={}, 原因是: {}", value, e);
            throw new DataFactoryException(ErrorCode.VALIDATE_ERROR_NO_WORDS, e.getMessage());
        }
        if (!isInsertSql) {
            log.error("必须包含insert或者insert overwrite语句");
            throw new DataFactoryException(ErrorCode.VALIDATE_ERROR, "必须包含insert或者insert overwrite语句");
        }
        if (isSelectSql) {
            log.error("暂时不支持直接使用select语句，请使用insert into select语法");
            throw new DataFactoryException(ErrorCode.VALIDATE_ERROR, "暂时不支持直接使用select语句，请使用insert into select语法");
        }
    }

    private static SqlParser.Config getSqlParserConfig(TableConfig tableConfig) {
        return JavaScalaConversionUtil.toJava(getCalciteConfig(tableConfig).getSqlParserConfig()).orElseGet(
                () -> {
                    SqlConformance conformance = getSqlConformance(tableConfig.getSqlDialect());
                    return SqlParser
                            .config()
                            .withParserFactory(FlinkSqlParserFactories.create(conformance))
                            .withConformance(conformance)
                            .withLex(Lex.JAVA)
                            .withIdentifierMaxLength(256);
                }
        );
    }

    private static CalciteConfig getCalciteConfig(TableConfig tableConfig) {
        return TableConfigUtils.getCalciteConfig(tableConfig);
    }

    private static FlinkSqlConformance getSqlConformance(SqlDialect sqlDialect) {
        switch (sqlDialect) {
            case HIVE:
                return FlinkSqlConformance.HIVE;
            case DEFAULT:
                return FlinkSqlConformance.DEFAULT;
            default:
                log.error("Unsupported SQL dialect: " + sqlDialect);
                throw new DataFactoryException(ErrorCode.DOES_NOT_SUPPORT);
        }
    }

    /**
     * 字符串转sql
     */
    public static List<String> toSqlList(String sql) {
        if (StringUtils.isEmpty(sql)) {
            return Collections.emptyList();
        }
        return Arrays.asList(sql.split("\n"));
    }

    public static List<SqlCommandCall> fileToSql(List<String> lineList) {

        if (CollectionUtils.isEmpty(lineList)) {
            throw new DataFactoryException(ErrorCode.MISSING_SERVLET_REQUEST_PARAMETER);
        }

        List<SqlCommandCall> sqlCommandCallList = new ArrayList<>();

        StringBuilder stmt = new StringBuilder();

        for (String line : lineList) {
            //开头是 -- 的表示注释
            if (line.trim().isEmpty() || line.startsWith("--") ||
                    trimStart(line).startsWith("--")) {
                continue;
            }
            stmt.append("\n").append(line);
            if (line.trim().endsWith(";")) {
                Optional<SqlCommandCall> optionalCall = parse(stmt.toString());
                if (optionalCall.isPresent()) {
                    sqlCommandCallList.add(optionalCall.get());
                } else {
                    log.error("不支持该语法使用: " + stmt + "'");
                    throw new DataFactoryException(ErrorCode.FLINK_SQL_ERROR, stmt);
                }
                stmt.setLength(0);
            }
        }

        return sqlCommandCallList;

    }


    private static Optional<SqlCommandCall> parse(String stmt) {
        stmt = stmt.trim();
        if (stmt.endsWith(";")) {
            stmt = stmt.substring(0, stmt.length() - 1).trim();
        }
        for (SqlCommand cmd : SqlCommand.values()) {
            final Matcher matcher = cmd.getPattern().matcher(stmt);
            if (matcher.matches()) {
                final String[] groups = new String[matcher.groupCount()];
                for (int i = 0; i < groups.length; i++) {
                    groups[i] = matcher.group(i + 1);
                }
                return cmd.getOperandConverter().apply(groups)
                        .map((operands) -> new SqlCommandCall(cmd, operands));
            }
        }
        return Optional.empty();
    }


    private static String trimStart(String str) {
        if (StringUtils.isEmpty(str)) {
            return str;
        }
        final char[] value = str.toCharArray();

        int start = 0, last = str.length() - 1;
        int end = last;
        while ((start <= end) && (value[start] <= ' ')) {
            start++;
        }
        if (start == 0 && end == last) {
            return str;
        }
        if (start >= end) {
            return "";
        }
        return str.substring(start, end);
    }
}
